import jax
import numpy as np
import numpy as onp

from collections import defaultdict

from matplotlib import pylab as plt



def interpolation(task, key, param_list, test_indices, test_ds, save_dir):
    """Evaluate and plot test error along linear paths between parameter sets.

    For every consecutive pair ``(param_list[p], param_list[p + 1])`` the two
    parameter pytrees are mixed leaf-wise as ``(1 - lam) * a + lam * b`` for 10
    evenly spaced ``lam`` in [0, 1] (both endpoints included). The test error
    (in percent) of each mixture is recorded. All errors are saved to
    ``{save_dir}/test_error.npy`` and plotted to ``{save_dir}/interpolation.png``.

    Args:
        task: Object exposing ``logit(param, key, batch)``; its output is passed
            through ``log_softmax`` and its argmax compared to ``batch['label']``.
        key: Passed unchanged to ``task.logit``.
        param_list: Sequence of parameter pytrees with identical structure.
            Must contain at least two entries.
        test_indices: Index array whose leading axis is vmapped over; each
            slice selects one batch from ``test_ds`` (presumably shaped
            ``(num_batches, batch_size)`` -- confirm against caller).
        test_ds: Pytree of arrays (including a ``'label'`` entry) that can be
            indexed by a slice of ``test_indices``.
        save_dir: Directory the ``.npy`` and ``.png`` outputs are written to.

    Raises:
        ValueError: If ``param_list`` has fewer than two entries.
    """
    n_params = len(param_list)
    if n_params < 2:
        # Interpolation needs at least one segment; without this check the
        # function failed later with an IndexError or a tick/label mismatch.
        raise ValueError(
            f'interpolation needs at least two parameter sets, got {n_params}')

    def _batch_acc(param, batch_indices):
        # Accuracy of `param` on the single batch selected by `batch_indices`.
        batch = jax.tree_util.tree_map(lambda x: x[batch_indices], test_ds)
        logits_ = jax.nn.log_softmax(task.logit(param, key, batch))
        return (logits_.argmax(-1) == batch['label']).mean()

    def acc(param):
        # Mean of the per-batch accuracies: vmap maps over the leading axis of
        # test_indices while broadcasting the same `param` to every batch.
        accs = jax.vmap(_batch_acc, in_axes=(None, 0))(param, test_indices)
        return accs.mean().item()

    def _mix(param1, param2, lam):
        # Leaf-wise convex combination of two identically structured pytrees.
        return jax.tree_util.tree_map(
            lambda x, y: (1 - lam) * x + lam * y, param1, param2)

    lamb = np.linspace(0, 1, 10)
    test_error_list = []
    # Both endpoints are sampled for every segment, so each interior parameter
    # set is evaluated twice (lam=1 of one segment, lam=0 of the next). This is
    # kept so the saved array's length and layout stay unchanged.
    for param1, param2 in zip(param_list, param_list[1:]):
        for lam in lamb:
            test_error = (1.0 - acc(_mix(param1, param2, lam))) * 100
            test_error_list.append(test_error)

    np.save(f'{save_dir}/test_error', np.array(test_error_list))

    # One marker per entry of param_list, spread evenly along the curve.
    markers_on = np.linspace(0, len(test_error_list) - 1, n_params).astype(int)
    # Interior markers sit where one segment ends (lam=1) and the next
    # starts (lam=0), hence the '1/0' label.
    ticks_labels = ['0'] + ['1/0'] * (n_params - 2) + ['1']

    fig = plt.figure(figsize=(5, 4))
    try:
        plt.plot(test_error_list, markevery=markers_on, marker='o')
        plt.ylabel('Test error (%)')
        # Raw string: '\l' is not a valid escape sequence in a normal literal.
        plt.xlabel(r'Mixing coefficient $\lambda$')
        plt.xticks(markers_on, labels=ticks_labels)
        plt.savefig(f'{save_dir}/interpolation.png')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)